% GNF demo for clustering with the MNIST database
clear

% ---- Experiment settings ------------------------------------------------
MaxNeurons = 50;     % upper bound on the number of neurons in each graph
Even       = 1;      % digit subset: 0 = odd digits, 1 = even digits, 2 = all
PCA        = 1;      % 0 = train on raw pixels, 1 = train on a global PCA projection

% ---- Growing Neural Gas hyper-parameters --------------------------------
% These are the settings used in Fritzke's original GNG paper (1995).
% NOTE(review): the per-parameter roles below follow that paper's notation
% and are assumptions about how TrainGNF consumes them -- confirm there.
Lambda   = 100;      % presumably the neuron-insertion interval (lambda)
Epochs   = 2;        % number of passes over the training samples
EpsilonB = 0.2;      % presumably the winner's learning rate (epsilon_b)
EpsilonN = 0.006;    % presumably the neighbours' learning rate (epsilon_n)
Alpha    = 0.5;      % presumably the error decay applied on insertion (alpha)
AMax     = 50;       % presumably the maximum edge age (a_max)
D        = 0.995;    % presumably the global error decay factor (d)

% Generate data from MNIST database.
% NOTE(review): `images` is assumed to hold one flattened square image per
% column (NumFeatures x NumSamples) -- confirm against MNISTTrain.mat.
load('MNISTTrain.mat','images');
AllSamples = images;
[NumFeatures,NumSamples] = size(AllSamples);

% Image geometry used later to reshape each column back into a picture.
NumRowsImg = sqrt(NumFeatures);
if NumRowsImg ~= fix(NumRowsImg)
    error('Expected square images, but NumFeatures = %d is not a perfect square.', NumFeatures);
end
NumColsImg = NumRowsImg;
ImageSize = 0.028;   % NOTE(review): not referenced elsewhere in this script -- confirm it is needed

% Digits selection
load('MNISTTrainLabels.mat','labels');
% Force a row vector regardless of how the labels were stored, so that the
% find() results below are rows and concatenate horizontally.
labels = labels(:)';
if Even == 1
    disp('EVEN DIGITS');
    NdxValidDigits = (0:2:8);
elseif Even == 0
    disp('ODD DIGITS');
    NdxValidDigits = (1:2:9);
else
    disp('ALL DIGITS');
    NdxValidDigits = (0:9);
end

% Collect the column indices of the selected digits, grouped digit by digit
% (same ordering as before), without growing an array inside the loop.
NdxPerDigit = cell(1, numel(NdxValidDigits));
for k = 1:numel(NdxValidDigits)
    NdxPerDigit{k} = find(labels == NdxValidDigits(k));
end
NdxDigits = [NdxPerDigit{:}];
if isempty(NdxDigits)
    error('No training samples found for the selected digits.');
end
Samples = AllSamples(:,NdxDigits);
[NumFeatures, NumSamples] = size(Samples);

% The training length is Epochs passes over the *selected* samples, so it is
% computed only after the digit selection. Computing it from AllSamples would
% silently inflate the number of epochs for the Even/Odd subsets.
NumSteps = Epochs*NumSamples;

if PCA
    % Perform a global PCA: project every sample onto the Dimension leading
    % eigenvectors of the sample covariance, after centring on the mean.
    disp('PCA');
    Dimension = 15;                     % number of principal components kept
    % Column mean (NumFeatures x 1). The explicit dimension argument keeps it
    % a column vector even when Samples has a single column, where
    % mean(Samples')' would collapse to a scalar.
    GlobalMean = mean(Samples, 2);
    CovGlobal = cov(Samples');
    [Uq, Lambdaq] = eigs(CovGlobal,Dimension,'LM');
    Lambdaq = diag(Lambdaq);            % retained eigenvalues as a vector
    % Centre each column on GlobalMean and project: Dimension x NumSamples.
    Samples_zq = Uq'*bsxfun(@minus, Samples, GlobalMean);
    Samples = Samples_zq;
else
    disp('NO PCA');
end

% Train the GNF model on Samples (PCA coordinates when PCA is enabled).
Model = TrainGNF(Samples,MaxNeurons,Lambda,EpsilonB,EpsilonN,Alpha,AMax,D,NumSteps);

% With PCA enabled the learned means live in PCA coordinates. Map them back
% to pixel space: back-project with Uq, then add the global mean to every column.
if PCA
    Model.Means = bsxfun(@plus, Uq*Model.Means, GlobalMean);
end

% Plot the model. NOTE(review): presumably each mean is shown as a
% NumRowsImg x NumColsImg image -- confirm in PlotGNFImages.
Handle = PlotGNFImages(Model,NumRowsImg,NumColsImg);